import math
import os
import random
from typing import Any, Dict, List, Tuple, Union

import torch
import torch.nn.functional as F
from omegaconf import ListConfig
from sgm.modules import UNCONDITIONAL_CONFIG
from sgm.modules.autoencoding.temporal_ae import VideoDecoder
from sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from sgm.util import (default, disabled_train, get_obj_from_str,
                      instantiate_from_config, log_txt_as_img)
from torch import nn

from sat import mpu
from sat.helpers import print_rank0
from sat.model.finetune.lora2 import merge_linear_lora


class SATVideoDiffusionEngine(nn.Module):
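    """Video diffusion engine for SAT training and sampling.

    Builds the denoising network, denoiser, sampler, conditioner and frozen
    first-stage autoencoder from ``args.model_config``, and exposes the
    training step (``shared_step``), sampling (``sample``) and logging
    (``log_video``) entry points.
    """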

    def __init__(self, args, **kwargs):
        super().__init__()

        model_config = args.model_config
        # model args preprocess
        log_keys = model_config.get('log_keys', None)
        input_key = model_config.get('input_key', 'mp4')
        network_config = model_config.get('network_config', None)
        network_wrapper = model_config.get('network_wrapper', None)
        denoiser_config = model_config.get('denoiser_config', None)
        sampler_config = model_config.get('sampler_config', None)
        conditioner_config = model_config.get('conditioner_config', None)
        first_stage_config = model_config.get('first_stage_config', None)
        loss_fn_config = model_config.get('loss_fn_config', None)
        scale_factor = model_config.get('scale_factor', 1.0)
        latent_input = model_config.get('latent_input', False)
        disable_first_stage_autocast = model_config.get(
            'disable_first_stage_autocast', False)
        no_cond_log = model_config.get('no_cond_log', False)
        not_trainable_prefixes = model_config.get(
            'not_trainable_prefixes', ['first_stage_model', 'conditioner'])
        compile_model = model_config.get('compile_model', False)
        en_and_decode_n_samples_a_time = model_config.get(
            'en_and_decode_n_samples_a_time', None)
        lr_scale = model_config.get('lr_scale', None)
        lora_train = model_config.get('lora_train', False)
        self.use_pd = model_config.get('use_pd',
                                       False)  # progressive distillation

        self.log_keys = log_keys
        self.input_key = input_key
        self.not_trainable_prefixes = not_trainable_prefixes
        self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
        self.lr_scale = lr_scale
        self.lora_train = lora_train
        self.noised_image_input = model_config.get('noised_image_input', False)
        self.noised_image_all_concat = model_config.get(
            'noised_image_all_concat', False)
        self.noised_image_dropout = model_config.get('noised_image_dropout',
                                                     0.0)
        if args.fp16:
            dtype = torch.float16
            dtype_str = 'fp16'
        elif args.bf16:
            dtype = torch.bfloat16
            dtype_str = 'bf16'
        else:
            dtype = torch.float32
            dtype_str = 'fp32'
        self.dtype = dtype
        self.dtype_str = dtype_str

        network_config['params']['dtype'] = dtype_str
        model = instantiate_from_config(network_config)
        self.model = get_obj_from_str(
            default(network_wrapper,
                    OPENAIUNETWRAPPER))(model,
                                        compile_model=compile_model,
                                        dtype=dtype)

        self.denoiser = instantiate_from_config(denoiser_config)
        self.sampler = instantiate_from_config(
            sampler_config) if sampler_config is not None else None
        self.conditioner = instantiate_from_config(
            default(conditioner_config, UNCONDITIONAL_CONFIG))

        self._init_first_stage(first_stage_config)

        self.loss_fn = instantiate_from_config(
            loss_fn_config) if loss_fn_config is not None else None

        self.latent_input = latent_input
        self.scale_factor = scale_factor
        self.disable_first_stage_autocast = disable_first_stage_autocast
        self.no_cond_log = no_cond_log
        self.device = args.device

    def disable_untrainable_params(self):
        total_trainable = 0
        for n, p in self.named_parameters():
            if not p.requires_grad:
                continue
            flag = False
            for prefix in self.not_trainable_prefixes:
                if n.startswith(prefix) or prefix == 'all':
                    flag = True
                    break

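            # LoRA low-rank matrices stay trainable even when their parent
            # module matches a not-trainable prefix.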
            lora_prefix = ['matrix_A', 'matrix_B']
            for prefix in lora_prefix:
                if prefix in n:
                    flag = False
                    break

            if flag:
                p.requires_grad_(False)
            else:
                total_trainable += p.numel()

        print_rank0(
            f'***** Total trainable parameters: {total_trainable} *****')

    def reinit(self, parent_model=None):
        # reload the initial params from previous trained modules
        # you can also get access to other mixins through parent_model.get_mixin().
        pass

    def merge_lora(self):
        # Fold trained LoRA weights back into the adaLN modulation linears.
        for m in self.model.diffusion_model.mixins.adaln_layer.adaLN_modulations:
            m[1] = merge_linear_lora(m[1])

    def _init_first_stage(self, config):
        # Instantiate the first-stage autoencoder frozen in eval mode.
        model = instantiate_from_config(config).eval()
        model.train = disabled_train
        for param in model.parameters():
            param.requires_grad = False
        self.first_stage_model = model

    def forward(self, x, batch):
        loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x,
                            batch)
        loss_mean = loss.mean()
        loss_dict = {'loss': loss_mean}
        return loss_mean, loss_dict

    def add_noise_to_first_frame(self, image):
        # Per-sample noise strength drawn from a log-normal distribution:
        # log(sigma) ~ N(-3.0, 0.5^2).
        sigma = torch.normal(mean=-3.0, std=0.5,
                             size=(image.shape[0], )).to(self.device)
        sigma = torch.exp(sigma).to(image.dtype)
        image_noise = torch.randn_like(image) * sigma[:, None, None, None,
                                                      None]
        image = image + image_noise
        return image

    @torch.no_grad()
    def save_memory_encode_first_stage(self, x, batch):
        # Encode a 49-frame clip in chunks of 13 + 12 + 12 + 12 frames to
        # bound peak VAE memory; the cache is cleared on the last chunk.
        splits_x = torch.split(x, [13, 12, 12, 12], dim=2)
        all_out = []

        with torch.autocast('cuda', enabled=False):
            for idx, input_x in enumerate(splits_x):
                clear_fake_cp_cache = idx == len(splits_x) - 1
                out = self.first_stage_model.encode(
                    input_x.contiguous(),
                    clear_fake_cp_cache=clear_fake_cp_cache)
                all_out.append(out)

        z = torch.cat(all_out, dim=2)
        # Hardcoded latent scale factor; expected to match the configured
        # `scale_factor` used in `encode_first_stage`.
        z = 1.15258426 * z
        return z

    def shared_step(self, batch: Dict) -> Any:
        x = self.get_input(batch)
        #      print(f"this is iteration {self.share_cache['iteration']}", flush=True)
        #     print(f'''{"train_size_range" in self.share_cache}''', flush=True)
        if 'train_size_range' in self.share_cache:
            # Randomly rescale the clip within the configured size range.
            train_size_range = self.share_cache.get('train_size_range')
            size_factor = random.uniform(*train_size_range)
            # Broadcast the size factor from rank 0 so all data-parallel
            # ranks resize identically.
            size_factor = torch.tensor(size_factor).to(self.device)
            torch.distributed.broadcast(size_factor,
                                        src=0,
                                        group=mpu.get_data_parallel_group())
            target_size = (int(x.shape[3] * size_factor),
                           int(x.shape[4] * size_factor))
            # Round the target size down to a multiple of 16.
            b, t, c, h, w = x.shape
            # Flatten time into the batch dimension: (b * t, c, h, w).
            x = x.reshape(b * t, c, h, w)
            target_size = (target_size[0] // 16 * 16,
                           target_size[1] // 16 * 16)
            x = F.interpolate(x,
                              size=target_size,
                              mode='bilinear',
                              align_corners=False,
                              antialias=True)
            # Restore the (b, t, c, h, w) layout.
            x = x.reshape(b, t, c, target_size[0], target_size[1])

        if self.lr_scale is not None:
            lr_x = F.interpolate(x,
                                 scale_factor=1 / self.lr_scale,
                                 mode='bilinear',
                                 align_corners=False)
            lr_x = F.interpolate(lr_x,
                                 scale_factor=self.lr_scale,
                                 mode='bilinear',
                                 align_corners=False)
            lr_z = self.encode_first_stage(lr_x, batch)
            batch['lr_input'] = lr_z

        x = x.permute(0, 2, 1, 3, 4).contiguous()
        if self.noised_image_input:
            image = x[:, :, 0:1]
            image = self.add_noise_to_first_frame(image)
            image = self.encode_first_stage(image, batch)
        b, c, t, h, w = x.shape
        if t == 49 and (h * w) > 480 * 720:
            if os.environ.get('DEBUGINFO', None) is not None:
                print(
                    f'save memory encode first stage with in shape {x.shape}, {x.mean()}'
                )
            x = self.save_memory_encode_first_stage(x, batch)
        else:
            x = self.encode_first_stage(x, batch)

        x = x.permute(0, 2, 1, 3, 4).contiguous()

        if 'ref_mp4' in self.share_cache:
            if 'disable_ref' not in self.share_cache:
                ref_mp4 = self.share_cache.pop('ref_mp4')
                ref_mp4 = ref_mp4.to(self.dtype).to(self.device)
                ref_mp4 = ref_mp4.permute(0, 2, 1, 3, 4).contiguous()
                ref_x = self.encode_first_stage(ref_mp4, batch)
                ref_x = ref_x.permute(0, 2, 1, 3, 4).contiguous()
                self.share_cache['ref_x'] = ref_x

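        # Build the image-conditioning latents: either tile the first-frame
        # latent across all frames, or pad with zeros after the first frame.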
        if self.noised_image_input:
            image = image.permute(0, 2, 1, 3, 4).contiguous()
            if self.noised_image_all_concat:
                image = image.repeat(1, x.shape[1], 1, 1, 1)
            else:
                image = torch.concat([image, torch.zeros_like(x[:, 1:])],
                                     dim=1)
            if random.random() < self.noised_image_dropout:
                image = torch.zeros_like(image)
            batch['concat_images'] = image

        loss, loss_dict = self(x, batch)
        return loss, loss_dict

    def get_input(self, batch):
        return batch[self.input_key].to(self.dtype)

    @torch.no_grad()
    def decode_first_stage(self, z):
        z = 1.0 / self.scale_factor * z
        n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
        n_rounds = math.ceil(z.shape[0] / n_samples)
        all_out = []
        with torch.autocast('cuda',
                            enabled=not self.disable_first_stage_autocast):
            for n in range(n_rounds):
                if isinstance(self.first_stage_model.decoder, VideoDecoder):
                    kwargs = {
                        'timesteps': len(z[n * n_samples:(n + 1) * n_samples])
                    }
                else:
                    kwargs = {}
                out = self.first_stage_model.decode(
                    z[n * n_samples:(n + 1) * n_samples], **kwargs)
                all_out.append(out)
        out = torch.cat(all_out, dim=0)
        return out

    @torch.no_grad()
    def encode_first_stage(self, x, batch):
        frame = x.shape[2]

        if frame > 1 and self.latent_input:
            x = x.permute(0, 2, 1, 3, 4).contiguous()
            return x * self.scale_factor  # already encoded

        n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
        n_rounds = math.ceil(x.shape[0] / n_samples)
        all_out = []
        with torch.autocast('cuda',
                            enabled=not self.disable_first_stage_autocast):
            for n in range(n_rounds):
                out = self.first_stage_model.encode(x[n * n_samples:(n + 1) *
                                                      n_samples])
                all_out.append(out)
        z = torch.cat(all_out, dim=0)
        z = self.scale_factor * z
        return z

    @torch.no_grad()
    def sample(
        self,
        cond: Dict,
        uc: Union[Dict, None] = None,
        batch_size: int = 16,
        shape: Union[None, Tuple, List] = None,
        prefix=None,
        concat_images=None,
        **kwargs,
    ):
        randn = torch.randn(batch_size,
                            *shape).to(torch.float32).to(self.device)
        if hasattr(self, 'seeded_noise'):
            randn = self.seeded_noise(randn)

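        # If prefix latents are provided, keep them fixed and only sample
        # the remaining frames.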
        if prefix is not None:
            randn = torch.cat([prefix, randn[:, prefix.shape[1]:]], dim=1)

        # broadcast noise
        mp_size = mpu.get_model_parallel_world_size()
        if mp_size > 1:
            global_rank = torch.distributed.get_rank() // mp_size
            src = global_rank * mp_size
            torch.distributed.broadcast(randn,
                                        src=src,
                                        group=mpu.get_model_parallel_group())

        scale = None
        scale_emb = None

        def denoiser(input, sigma, c, **additional_model_inputs):
            return self.denoiser(self.model,
                                 input,
                                 sigma,
                                 c,
                                 concat_images=concat_images,
                                 **additional_model_inputs)

        if 'cfg' in self.share_cache:
            # Overwrite the stage-1 CFG scale from the config with the
            # value in share_cache.
            self.sampler.guider.scale = self.share_cache['cfg']
            print('overwriting stage-1 cfg scale with value from share_cache')

        samples = self.sampler(denoiser,
                               randn,
                               cond,
                               uc=uc,
                               scale=scale,
                               scale_emb=scale_emb,
                               num_steps=kwargs.get('num_steps', None))
        samples = samples.to(self.dtype)
        return samples

    @torch.no_grad()
    def log_conditionings(self, batch: Dict, n: int) -> Dict:
        """
        Defines heuristics to log different conditionings.
        These can be lists of strings (text-to-image), tensors, ints, ...
        """
        image_h, image_w = batch[self.input_key].shape[3:]
        log = dict()

        for embedder in self.conditioner.embedders:
            if ((self.log_keys is None) or
                (embedder.input_key
                 in self.log_keys)) and not self.no_cond_log:
                x = batch[embedder.input_key][:n]
                if isinstance(x, torch.Tensor):
                    if x.dim() == 1:
                        # class-conditional, convert integer to string
                        x = [str(x[i].item()) for i in range(x.shape[0])]
                        xc = log_txt_as_img((image_h, image_w),
                                            x,
                                            size=image_h // 4)
                    elif x.dim() == 2:
                        # size and crop cond and the like
                        x = [
                            'x'.join([str(xx) for xx in x[i].tolist()])
                            for i in range(x.shape[0])
                        ]
                        xc = log_txt_as_img((image_h, image_w),
                                            x,
                                            size=image_h // 20)
                    else:
                        raise NotImplementedError()
                elif isinstance(x, (List, ListConfig)):
                    if isinstance(x[0], str):
                        xc = log_txt_as_img((image_h, image_w),
                                            x,
                                            size=image_h // 20)
                    else:
                        raise NotImplementedError()
                else:
                    raise NotImplementedError()
                log[embedder.input_key] = xc
        return log

    @torch.no_grad()
    def log_video(
        self,
        batch: Dict,
        N: int = 8,
        ucg_keys: List[str] = None,
        only_log_video_latents=False,
        **kwargs,
    ) -> Dict:
        conditioner_input_keys = [
            e.input_key for e in self.conditioner.embedders
        ]
        if ucg_keys:
            assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
                'Each defined ucg key for sampling must be in the provided '
                f'conditioner input keys, but we have {ucg_keys} vs. '
                f'{conditioner_input_keys}')
        else:
            ucg_keys = conditioner_input_keys
        log = dict()

        x = self.get_input(batch)

        c, uc = self.conditioner.get_unconditional_conditioning(
            batch,
            force_uc_zero_embeddings=ucg_keys
            if len(self.conditioner.embedders) > 0 else [],
        )

        sampling_kwargs = {}

        N = min(x.shape[0], N)
        x = x.to(self.device)[:N]
        if not self.latent_input:
            log['inputs'] = x.to(torch.float32)
        x = x.permute(0, 2, 1, 3, 4).contiguous()
        z = self.encode_first_stage(x, batch)
        if not only_log_video_latents:
            log['reconstructions'] = self.decode_first_stage(z).to(
                torch.float32)
            log['reconstructions'] = log['reconstructions'].permute(
                0, 2, 1, 3, 4).contiguous()
        z = z.permute(0, 2, 1, 3, 4).contiguous()

        log.update(self.log_conditionings(batch, N))

        for k in c:
            if isinstance(c[k], torch.Tensor):
                c[k] = c[k][:N].to(self.device)
                uc[k] = uc[k][:N].to(self.device)

        if self.noised_image_input:
            image = x[:, :, 0:1]
            image = self.add_noise_to_first_frame(image)
            image = self.encode_first_stage(image, batch)
            image = image.permute(0, 2, 1, 3, 4).contiguous()
            image = torch.concat([image, torch.zeros_like(z[:, 1:])], dim=1)
            c['concat'] = image
            uc['concat'] = image

        samples = self.sample(c,
                              shape=z.shape[1:],
                              uc=uc,
                              batch_size=N,
                              **sampling_kwargs)  # b t c h w
        samples = samples.permute(0, 2, 1, 3, 4).contiguous()
        if only_log_video_latents:
            latents = 1.0 / self.scale_factor * samples
            log['latents'] = latents
        else:
            samples = self.decode_first_stage(samples).to(torch.float32)
            samples = samples.permute(0, 2, 1, 3, 4).contiguous()
            log['samples'] = samples
        return log


class SATUpscalerEngine(SATVideoDiffusionEngine):
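    """Upscaler variant of ``SATVideoDiffusionEngine``.

    Its ``shared_step`` computes an MSE reconstruction loss against a
    decoded reference clip from ``share_cache`` instead of the diffusion
    loss.
    """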

    def shared_step(self, batch: Dict) -> Any:
        x = self.get_input(batch)
        if self.lr_scale is not None:
            lr_x = F.interpolate(x,
                                 scale_factor=1 / self.lr_scale,
                                 mode='bilinear',
                                 align_corners=False)
            lr_x = F.interpolate(lr_x,
                                 scale_factor=self.lr_scale,
                                 mode='bilinear',
                                 align_corners=False)
            lr_z = self.encode_first_stage(lr_x, batch)
            batch['lr_input'] = lr_z

        x = x.permute(0, 2, 1, 3, 4).contiguous()
        if self.noised_image_input:
            image = x[:, :, 0:1]
            image = self.add_noise_to_first_frame(image)
            image = self.encode_first_stage(image, batch)

        x = self.encode_first_stage(x, batch)
        x = x.permute(0, 2, 1, 3, 4).contiguous()

        if 'ref_mp4' in self.share_cache:
            ref_mp4 = self.share_cache.pop('ref_mp4')
            ref_mp4 = ref_mp4.to(self.dtype).to(self.device)
            ref_mp4 = ref_mp4.permute(0, 2, 1, 3, 4).contiguous()
            ref_x = self.encode_first_stage(ref_mp4, batch)
            ref_x = ref_x.permute(0, 2, 1, 3, 4).contiguous()
            self.share_cache['ref_x'] = ref_x

        if self.noised_image_input:
            image = image.permute(0, 2, 1, 3, 4).contiguous()
            if self.noised_image_all_concat:
                image = image.repeat(1, x.shape[1], 1, 1, 1)
            else:
                image = torch.concat([image, torch.zeros_like(x[:, 1:])],
                                     dim=1)
            if random.random() < self.noised_image_dropout:
                image = torch.zeros_like(image)
            batch['concat_images'] = image

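        # Assumes 'ref_mp4' was present in share_cache above; otherwise
        # `ref_x` is undefined at this point.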
        ref_x = ref_x.permute(0, 2, 1, 3, 4).contiguous()
        ref_x = self.first_stage_model.decoder(ref_x)
        ref_x = ref_x.permute(0, 2, 1, 3, 4).contiguous()
        loss_mean = torch.mean(((x - ref_x)**2).reshape(x.shape[0], -1), 1)
        loss_mean = loss_mean.mean()
        loss_dict = {'loss': loss_mean}

        return loss_mean, loss_dict

    def disable_untrainable_params(self):
        pass

